import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle as pkl
import torch
import random
from sklearn import metrics
from sklearn.metrics import classification_report
# Make Times New Roman the default matplotlib font family for all text drawn after this point.
plt.rc('font',family='Times New Roman')
def set_random_seed(seed=0):
    """Seed every random number generator the project relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``. It then switches
    cuDNN to deterministic, non-benchmarked kernel selection.

    Args:
        seed: Integer seed shared by all generators. Defaults to 0.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def data_loader(args):
    """Load the pickled train/test DataFrames and min-max scale their features.

    Args:
        args: Object providing ``data_dir``, ``train_set`` and ``test_set``.
            The two file names are joined with ``data_dir`` and unpickled.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``:
        - ``X_train`` / ``X_test``: the columns F1..F24 cast to float32 and
          scaled with the *training* split's per-column min and max.
        - ``y_train`` / ``y_test``: the ``Label`` column cast to int32.

    Note:
        A training column whose max equals its min is scaled with a range of 1
        (i.e. it becomes ``x - min``). This avoids the NaN/inf that a division
        by zero would otherwise produce.
    """
    # F1, F2, ..., F24: the feature columns expected in both files.
    feature_title = ['F%d' % i for i in range(1, 25)]

    # Data acquisition.
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.data_dir, args.train_set), 'rb') as fh:
        df_train = pkl.load(fh)
    with open(os.path.join(args.data_dir, args.test_set), 'rb') as fh:
        df_test = pkl.load(fh)

    # Data preparation
    X_train = df_train[feature_title].astype(np.float32)
    X_test = df_test[feature_title].astype(np.float32)

    y_train = df_train['Label'].astype(np.int32)
    y_test = df_test['Label'].astype(np.int32)

    # Min-max normalization fitted on the training split only, so no test
    # statistics leak into the scaling.
    min_num = X_train.min(axis=0)
    value_range = X_train.max(axis=0) - min_num
    # A constant column has zero range; use 1 there to avoid 0/0 -> NaN.
    value_range = value_range.where(value_range != 0, 1)
    X_train = (X_train - min_num) / value_range
    X_test = (X_test - min_num) / value_range
    return X_train, y_train, X_test, y_test


def model_evaluation(y_pred, y_test, args, now_time, labels_name=None):
    """Report test-set metrics and save them with a normalized confusion matrix.

    Prints ``classification_report`` to stdout and writes two files:
    - ``Results/<Model_name><now_time>.csv``: the per-class report as a table.
    - ``Results/<Model_name><now_time>.png``: the confusion matrix with every
      row divided by that row's total (the number of true samples of the class).

    The ``Results`` directory is created if it does not exist. The figure is
    left open as the current pyplot figure.

    Args:
        y_pred: Predicted class labels.
        y_test: True class labels; flattened with ``np.ravel`` before use.
        args: Object with a ``Model_name`` attribute, used in the printed
            header and in the output file names.
        now_time: String appended to ``Model_name`` to form the file names.
        labels_name: Optional axis tick labels for the confusion matrix. If
            omitted, the matrix uses ['Health', 'Fault 1', 'Fault 2', 'Fault 3']
            when it is 4x4. Otherwise it uses the sorted class values present
            in ``y_test``/``y_pred``.
    """
    nowname = ''.join([args.Model_name, now_time])
    y_true = np.ravel(y_test)
    # to_csv / savefig below fail if the output directory is missing.
    os.makedirs("Results", exist_ok=True)

    print("---------- %s - Evaluation on Test Data ----------" % (args.Model_name))
    class_result = classification_report(y_true, y_pred, output_dict=True)
    print(classification_report(y_true, y_pred, digits=4))

    df = pd.DataFrame(class_result).transpose()
    df.to_csv("Results/%s.csv" % nowname, index=True)

    counts = metrics.confusion_matrix(y_true, y_pred)
    n_classes = counts.shape[0]
    if labels_name is None:
        if n_classes == 4:
            labels_name = ['Health', 'Fault 1', 'Fault 2', 'Fault 3']
        else:
            # confusion_matrix orders its rows/columns by the sorted labels
            # found in y_true and y_pred, so mirror that order for the ticks.
            present = np.unique(np.concatenate((y_true, np.ravel(y_pred))))
            labels_name = [str(c) for c in present]

    # Row-normalize: the fraction of each true class predicted as each label.
    # Classes with no true samples keep a row of zeros instead of NaN.
    row_totals = counts.sum(axis=1, keepdims=True)
    C = np.divide(counts, row_totals, out=np.zeros(counts.shape, dtype=float),
                  where=row_totals != 0)

    plt.figure(figsize=(6, 6), dpi=300)
    plt.matshow(C, cmap=plt.cm.Blues, fignum=0)  # change the colormap here if desired

    # Write each cell's value at (x=column, y=row) on the matrix image.
    for (row, col), value in np.ndenumerate(C):
        plt.annotate(round(value, 3), xy=(col, row), horizontalalignment='center',
                     verticalalignment='center', size=16)

    num_local = np.arange(len(labels_name))

    plt.xticks(num_local, labels_name, rotation=0, size=16)  # class labels on the x axis
    plt.yticks(num_local, labels_name, fontproperties='Times New Roman', size=16)  # class labels on the y axis

    plt.ylabel('True label', fontsize=16)
    plt.xlabel('Predicted label', fontsize=16)

    plt.savefig("Results/%s.png" % nowname)